{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71eadd60",
   "metadata": {},
   "source": [
    "# Velocity Control Kinematic Simulation\n",
    "\n",
    "### Requirements\n",
    "Install with pip:\n",
    "- numpy\n",
    "- scipy\n",
    "- matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d69cb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.spatial.transform import Rotation\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "# Render animations returned from a cell as interactive JavaScript/HTML players\n",
    "plt.rcParams[\"animation.html\"] = \"jshtml\"\n",
    "# Default resolution (dots per inch) for figures created in this notebook\n",
    "plt.rcParams['figure.dpi'] = 150  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec1b8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- HYPERPARAMS\n",
    "# Sim\n",
    "SPEEDUP = 4  # animation playback speed-up; passed to Viz.animate() as playback_speed\n",
    "\n",
    "TIME_HORIZON = 12  # simulated duration (presumably seconds, given the *_HZ rates -- confirm)\n",
    "SIM_HZ = 240  # rate used to construct Env; each Env.step() advances the state by 1 / SIM_HZ\n",
    "CONTROL_HZ = 20  # rate at which the main loop queries the controller for a new command\n",
    "VIZ_HZ = 24 / SPEEDUP  # frame recording rate; played back SPEEDUP times faster this yields 24 frames/s\n",
    "\n",
    "# Viz\n",
    "R_ROBOT = 0.2  # size parameter of the drawn robot outline (half-width and rear-arc radius)\n",
    "BUTT_SAMPLES = 20  # number of points sampled along the rounded rear arc of the outline\n",
    "X_LIM = [-2, 2]  # x-axis limits of the plot\n",
    "Y_LIM = [-2, 2]  # y-axis limits of the plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b44df483",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- SIMULATION ENV\n",
    "class Env:\n",
    "    \"\"\"\n",
    "    Kinematic unicycle simulator integrated with forward Euler steps.\n",
    "\n",
    "    State:\n",
    "        pos_state : pose (x, y, theta (radians))\n",
    "        vel_state : last commanded velocity (v, w)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hz):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            hz : simulation rate; each step() advances the state by 1 / hz\n",
    "        \"\"\"\n",
    "        self.pos_state = np.zeros(3)\n",
    "        self.vel_state = np.zeros(2)\n",
    "        self.dt = 1.0 / hz\n",
    "\n",
    "    def get_pose(self):\n",
    "        \"\"\"\n",
    "        Output:\n",
    "            copy of the current pose (x, y, theta)\n",
    "        \"\"\"\n",
    "        # Return a copy so callers cannot change the simulator state through the returned array\n",
    "        return self.pos_state.copy()\n",
    "\n",
    "    def step(self, vel_input=None):\n",
    "        \"\"\"\n",
    "        Advances the simulation by one time step.\n",
    "\n",
    "        Input:\n",
    "            vel_input : optional (v, w) command; when None the previous command is held\n",
    "        \"\"\"\n",
    "        if vel_input is not None:\n",
    "            # Keep a private copy so later changes to the caller's object do not leak in\n",
    "            self.vel_state = np.array(vel_input, dtype=float)\n",
    "\n",
    "        v_lin = self.vel_state[0]\n",
    "        w_ang = self.vel_state[1]\n",
    "        theta = self.pos_state[2]  # heading at the start of the step is used for both x and y\n",
    "\n",
    "        self.pos_state[0] += v_lin * np.cos(theta) * self.dt\n",
    "        self.pos_state[1] += v_lin * np.sin(theta) * self.dt\n",
    "        self.pos_state[2] += w_ang * self.dt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e86366",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- VISUALIZER\n",
    "class Viz:\n",
    "    \"\"\"Records (robot pose, goal pose) frames and replays them as a matplotlib animation.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.frames = []\n",
    "        # Figure and axes are created by animate()\n",
    "        self.fig = None\n",
    "        self.ax = None\n",
    "\n",
    "        # Robot outline in the robot frame: four corners of the rectangular front part ...\n",
    "        half_len = R_ROBOT / 2\n",
    "        corners = np.array(\n",
    "            [\n",
    "                [-half_len, -R_ROBOT],  # rear right\n",
    "                [half_len, -R_ROBOT],  # front right\n",
    "                [half_len, R_ROBOT],  # front left\n",
    "                [-half_len, R_ROBOT],  # rear left\n",
    "            ]\n",
    "        )\n",
    "        # ... followed by a semicircle of radius R_ROBOT centered at (-R_ROBOT / 2, 0) that runs\n",
    "        # from the rear-left corner back to the rear-right corner, closing the outline\n",
    "        phi = np.linspace(np.pi / 2, 3 * np.pi / 2, BUTT_SAMPLES)\n",
    "        rear_arc = np.stack([R_ROBOT * (np.cos(phi) - 0.5), R_ROBOT * np.sin(phi)], axis=1)\n",
    "\n",
    "        self.xy_verts_local = np.vstack([corners, rear_arc])\n",
    "\n",
    "    def _draw_robot(self, xyt, color):\n",
    "        \"\"\"\n",
    "        Draws the robot outline and a center marker at pose xyt on self.ax.\n",
    "\n",
    "        Input:\n",
    "            xyt   : pose (x, y, theta (radians)) as a numpy array\n",
    "            color : matplotlib color used for both the marker and the outline\n",
    "        \"\"\"\n",
    "        ct = np.cos(xyt[2])\n",
    "        st = np.sin(xyt[2])\n",
    "        rot = np.array([[ct, -st], [st, ct]])\n",
    "\n",
    "        # Rotate the local outline by theta, then translate it to (x, y)\n",
    "        xy_verts_world = (rot @ self.xy_verts_local.T).T + xyt[None, :2]\n",
    "\n",
    "        self.ax.plot(xyt[0], xyt[1], color=color, marker='o')\n",
    "        self.ax.plot(xy_verts_world[:, 0], xy_verts_world[:, 1], color=color)\n",
    "\n",
    "    def _render_frame(self, frame):\n",
    "        \"\"\"\n",
    "        FuncAnimation callback: redraws the axes for one recorded frame.\n",
    "\n",
    "        Input:\n",
    "            frame : (xyt_robot, xyt_goal) tuple as stored by record_frame()\n",
    "        \"\"\"\n",
    "        # Clear our own axes; plt.cla() would clear whichever axes pyplot considers current\n",
    "        self.ax.cla()\n",
    "\n",
    "        self.ax.set_xlim(*X_LIM)\n",
    "        self.ax.set_ylim(*Y_LIM)\n",
    "        self.ax.set_aspect(\"equal\")\n",
    "\n",
    "        # Goal in gray first so the robot (black) is drawn on top of it\n",
    "        self._draw_robot(frame[1], '0.7')\n",
    "        self._draw_robot(frame[0], '0')\n",
    "\n",
    "    def clear_frames(self):\n",
    "        \"\"\"Discards all recorded frames.\"\"\"\n",
    "        self.frames = []\n",
    "\n",
    "    def record_frame(self, xyt_robot, xyt_goal):\n",
    "        \"\"\"\n",
    "        Stores one frame for later playback.\n",
    "\n",
    "        Input:\n",
    "            xyt_robot : robot pose (x, y, theta)\n",
    "            xyt_goal  : goal pose (x, y, theta)\n",
    "        \"\"\"\n",
    "        # Snapshot both poses so frames stay valid if the caller keeps mutating its arrays\n",
    "        self.frames.append((np.array(xyt_robot, dtype=float), np.array(xyt_goal, dtype=float)))\n",
    "\n",
    "    def animate(self, playback_speed=1.0):\n",
    "        \"\"\"\n",
    "        Builds an animation of the recorded frames.\n",
    "\n",
    "        Input:\n",
    "            playback_speed : playback rate relative to the VIZ_HZ recording rate (must be > 0)\n",
    "        Output:\n",
    "            matplotlib.animation.FuncAnimation\n",
    "        \"\"\"\n",
    "        if playback_speed <= 0:\n",
    "            raise ValueError(\"playback_speed must be positive\")\n",
    "\n",
    "        # Interactive mode off so creating the figure does not display it by itself\n",
    "        plt.ioff()\n",
    "        self.fig, self.ax = plt.subplots()\n",
    "\n",
    "        return animation.FuncAnimation(\n",
    "            self.fig,\n",
    "            self._render_frame,\n",
    "            frames=self.frames,\n",
    "            interval=1000 / VIZ_HZ / playback_speed,  # milliseconds between frames\n",
    "            repeat=False,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f72b7170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- CONTROLLER\n",
    "from scipy.spatial.transform import Rotation\n",
    "\n",
    "# Limits used by Controller. The trailing expressions appear to record the robot\n",
    "# parameters each value was taken from (not available in this notebook -- confirm).\n",
    "V_MAX_DEFAULT = 0.2  # max linear velocity; base.params[\"motion\"][\"default\"][\"vel_m\"]\n",
    "W_MAX_DEFAULT = 0.45  # max angular velocity; (vel_m_max - vel_m_default) / wheel_separation_m\n",
    "ACC_LIN = 0.4  # linear acceleration of the velocity profile; 0.5 * base.params[\"motion\"][\"max\"][\"accel_m\"]\n",
    "ACC_ANG = 1.2  # angular acceleration of the velocity profile; 0.5 * (accel_m_max - accel_m_default) / wheel_separation_m\n",
    "MAX_HEADING_ANG = np.pi / 4  # heading error (radians) above which the linear velocity limit is 0 (turn in place)\n",
    "\n",
    "def transform_global_to_base(XYT, current_pose):\n",
    "    \"\"\"\n",
    "    Expresses a global-frame pose in the frame of the robot base.\n",
    "\n",
    "    Input:\n",
    "        XYT          : pose (x, y, theta (radians)) in the global frame; not modified\n",
    "        current_pose : base pose (x, y, theta (radians)) in the global frame\n",
    "    Output:\n",
    "        [x, y, theta] : XYT relative to the base, with theta wrapped to [-pi, pi)\n",
    "    \"\"\"\n",
    "    # np.array always copies, so the caller's pose is never written to, and float64\n",
    "    # avoids the precision loss of a float32 round trip\n",
    "    target = np.array(XYT, dtype=np.float64).ravel()\n",
    "    base = np.asarray(current_pose, dtype=np.float64).ravel()\n",
    "\n",
    "    dx = target[0] - base[0]\n",
    "    dy = target[1] - base[1]\n",
    "\n",
    "    # Rotate the global offset by -theta_base (transpose of the base rotation matrix)\n",
    "    ct = np.cos(base[2])\n",
    "    st = np.sin(base[2])\n",
    "    x_base = ct * dx + st * dy\n",
    "    y_base = -st * dx + ct * dy\n",
    "\n",
    "    # Wrap the yaw difference so the controller always turns the short way around\n",
    "    new_T = (target[2] - base[2] + np.pi) % (2.0 * np.pi) - np.pi\n",
    "\n",
    "    return [x_base, y_base, new_T]\n",
    "\n",
    "class Controller:\n",
    "    \"\"\"\n",
    "    Goal-reaching velocity controller for a unicycle-type base.\n",
    "\n",
    "    forward() maps the current pose to a (v, w) command that first drives to the goal\n",
    "    position and then, if track_yaw is set, rotates to the goal yaw.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, track_yaw=True):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            track_yaw : if False, the yaw component of the goal is ignored\n",
    "        \"\"\"\n",
    "        self.track_yaw = track_yaw\n",
    "\n",
    "        # Params\n",
    "        self.v_max = V_MAX_DEFAULT\n",
    "        self.w_max = W_MAX_DEFAULT\n",
    "        # Tolerances: distance / angle covered in one control period at max velocity\n",
    "        self.lin_error_tol = self.v_max / CONTROL_HZ\n",
    "        self.ang_error_tol = self.w_max / CONTROL_HZ\n",
    "\n",
    "        # Init\n",
    "        self.xyt_goal = np.zeros(3)\n",
    "        self.dxyt_goal = np.zeros(3)\n",
    "\n",
    "    def set_goal(self, goal, vel_goal=None):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            goal     : goal pose (x, y, theta (radians))\n",
    "            vel_goal : optional goal velocity; stored but not used by forward()\n",
    "        \"\"\"\n",
    "        self.xyt_goal = goal\n",
    "        if vel_goal is not None:\n",
    "            self.dxyt_goal = vel_goal\n",
    "\n",
    "    def _compute_error_pose(self, xyt_base):\n",
    "        \"\"\"\n",
    "        Computes the goal pose relative to the given base pose.\n",
    "\n",
    "        Input:\n",
    "            xyt_base : current base pose (x, y, theta (radians))\n",
    "        Output:\n",
    "            [x, y, theta] error in the base frame; theta is 0 when yaw is not tracked\n",
    "        \"\"\"\n",
    "        xyt_err = transform_global_to_base(self.xyt_goal, xyt_base)\n",
    "        if not self.track_yaw:\n",
    "            xyt_err[2] = 0.0\n",
    "\n",
    "        return xyt_err\n",
    "\n",
    "    @staticmethod\n",
    "    def _velocity_feedback_control(x_err, a, v_max):\n",
    "        \"\"\"\n",
    "        Computes velocity based on distance from target.\n",
    "        Used for both linear and angular motion.\n",
    "\n",
    "        Current implementation: Trapezoidal velocity profile\n",
    "\n",
    "        Input:\n",
    "            x_err : signed error to the target\n",
    "            a     : acceleration of the profile (must be > 0)\n",
    "            v_max : velocity magnitude cap\n",
    "        Output:\n",
    "            signed velocity with the same sign as x_err\n",
    "        \"\"\"\n",
    "        # Time needed to cover |x_err| at constant acceleration: x_err = (1/2) * a * t^2\n",
    "        t = np.sqrt(2.0 * abs(x_err) / a)\n",
    "        v = min(a * t, v_max)\n",
    "        return v * np.sign(x_err)\n",
    "\n",
    "    @staticmethod\n",
    "    def _turn_rate_limit(lin_err, heading_diff, w_max, tol=0.0):\n",
    "        \"\"\"\n",
    "        Compute velocity limit that prevents path from overshooting goal\n",
    "\n",
    "        heading error decrease rate > linear error decrease rate\n",
    "        (w - v * np.sin(phi) / D) / phi > v * np.cos(phi) / D\n",
    "        v < (w / phi) / (np.sin(phi) / D / phi + np.cos(phi) / D)\n",
    "        v < w * D / (np.sin(phi) + phi * np.cos(phi))\n",
    "\n",
    "        (D = linear error, phi = angular error)\n",
    "\n",
    "        Input:\n",
    "            lin_err      : distance to the goal (>= 0)\n",
    "            heading_diff : absolute heading error (>= 0)\n",
    "            w_max        : angular velocity assumed available for turning\n",
    "            tol          : accepted for interface compatibility; currently unused\n",
    "        Output:\n",
    "            upper bound on the linear velocity (0 when heading_diff > MAX_HEADING_ANG)\n",
    "        \"\"\"\n",
    "        assert lin_err >= 0.0\n",
    "        assert heading_diff >= 0.0\n",
    "\n",
    "        if heading_diff > MAX_HEADING_ANG:\n",
    "            return 0.0\n",
    "\n",
    "        # The 1e-5 term keeps the denominator non-zero when heading_diff is 0\n",
    "        denom = np.sin(heading_diff) + heading_diff * np.cos(heading_diff) + 1e-5\n",
    "        return w_max * lin_err / denom\n",
    "\n",
    "    def _feedback_traj_track(self, xyt_err, k1=1.0, k2=1.0, k3=1.0):\n",
    "        \"\"\"\n",
    "        Alternative feedback law computed directly from the base-frame pose error.\n",
    "        Not used by forward().\n",
    "\n",
    "        Input:\n",
    "            xyt_err    : pose error (x, y, theta (radians)) in the base frame\n",
    "            k1, k2, k3 : feedback gains; the defaults are untuned placeholders\n",
    "        Output:\n",
    "            np.array([v, w]) clipped to [-v_max, v_max] and [-w_max, w_max]\n",
    "\n",
    "        NOTE: the law divides by cos(theta error), so it is singular at +-pi/2.\n",
    "        \"\"\"\n",
    "        x_err, y_err, t_err = xyt_err[0], xyt_err[1], xyt_err[2]\n",
    "        tan_t = np.tan(t_err)\n",
    "        cos_t = np.cos(t_err)\n",
    "\n",
    "        v_raw = self.v_max * (k1 * x_err + y_err * tan_t) / cos_t\n",
    "        w_raw = self.v_max * (k2 * y_err + k3 * tan_t) / cos_t**2\n",
    "\n",
    "        # Limit the magnitude in both directions, not only the positive one\n",
    "        v_out = np.clip(v_raw, -self.v_max, self.v_max)\n",
    "        w_out = np.clip(w_raw, -self.w_max, self.w_max)\n",
    "        return np.array([v_out, w_out])\n",
    "\n",
    "    def _feedback_simple(self, xyt_err):\n",
    "        \"\"\"\n",
    "        Two-phase feedback: drive to the goal position, then rotate to the goal yaw.\n",
    "\n",
    "        Input:\n",
    "            xyt_err : pose error (x, y, theta (radians)) in the base frame\n",
    "        Output:\n",
    "            (v_cmd, w_cmd) linear and angular velocity command\n",
    "        \"\"\"\n",
    "        v_cmd = w_cmd = 0\n",
    "\n",
    "        lin_err_abs = np.linalg.norm(xyt_err[0:2])\n",
    "        ang_err = xyt_err[2]\n",
    "\n",
    "        # Go to goal XY position if not there yet\n",
    "        if lin_err_abs > self.lin_error_tol:\n",
    "            # Direction of the goal as seen from the base\n",
    "            heading_err = np.arctan2(xyt_err[1], xyt_err[0])\n",
    "            heading_err_abs = abs(heading_err)\n",
    "\n",
    "            # Compute linear velocity\n",
    "            v_raw = self._velocity_feedback_control(\n",
    "                lin_err_abs, ACC_LIN, self.v_max\n",
    "            )\n",
    "            v_limit = self._turn_rate_limit(\n",
    "                lin_err_abs,\n",
    "                heading_err_abs,\n",
    "                self.w_max / 2.0,\n",
    "                tol=self.lin_error_tol,\n",
    "            )\n",
    "            # Never command reverse motion; cap by the turn-rate limit\n",
    "            v_cmd = np.clip(v_raw, 0.0, v_limit)\n",
    "\n",
    "            # Compute angular velocity\n",
    "            w_cmd = self._velocity_feedback_control(\n",
    "                heading_err, ACC_ANG, self.w_max\n",
    "            )\n",
    "\n",
    "        # Rotate to correct yaw if yaw tracking is on and XY position is at goal\n",
    "        elif abs(ang_err) > self.ang_error_tol and self.track_yaw:\n",
    "            # Compute angular velocity\n",
    "            w_cmd = self._velocity_feedback_control(\n",
    "                ang_err, ACC_ANG, self.w_max\n",
    "            )\n",
    "\n",
    "        return v_cmd, w_cmd\n",
    "\n",
    "    def forward(self, xyt):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            xyt : current base pose (x, y, theta (radians))\n",
    "        Output:\n",
    "            (v_cmd, w_cmd) velocity command\n",
    "        \"\"\"\n",
    "        xyt_err = self._compute_error_pose(xyt)\n",
    "        return self._feedback_simple(xyt_err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6b21681",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- GOAL\n",
    "# Target pose (x, y, theta (radians)); converted to an array and handed to Controller.set_goal() below\n",
    "GOAL = [0.2, 0.5, -0.3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c25b035c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- MAIN\n",
    "# initialize objects\n",
    "xyt_goal = np.array(GOAL)\n",
    "\n",
    "viz = Viz()\n",
    "env = Env(SIM_HZ)\n",
    "agent = Controller()\n",
    "agent.set_goal(xyt_goal)\n",
    "\n",
    "# simulate\n",
    "dt_control = 1 / CONTROL_HZ\n",
    "dt_viz = 1 / VIZ_HZ\n",
    "t_control = 0.0  # next time at which a new command is computed\n",
    "t_viz = 0.0  # next time at which a frame is recorded\n",
    "T_EPS = 1e-9  # absorbs float round-off when comparing against the accumulated schedule times\n",
    "\n",
    "n_steps = int(TIME_HORIZON * SIM_HZ)\n",
    "for step in range(n_steps):\n",
    "    # Derive time from the step index so it matches the env.dt that Env integrates with;\n",
    "    # np.linspace(0, T, N) spaces its samples by T / (N - 1), slightly more than 1 / SIM_HZ\n",
    "    t = step * env.dt\n",
    "    xyt = env.get_pose()\n",
    "\n",
    "    # Between control ticks no command is passed and Env holds the previous one\n",
    "    vel_command = None\n",
    "    if t + T_EPS >= t_control:\n",
    "        vel_command = agent.forward(xyt)\n",
    "        t_control += dt_control\n",
    "    if t + T_EPS >= t_viz:\n",
    "        viz.record_frame(xyt.copy(), xyt_goal.copy())\n",
    "        t_viz += dt_viz\n",
    "\n",
    "    env.step(vel_command)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "128d46ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the animation from the recorded frames; as the last expression of the cell the\n",
    "# returned animation object is displayed by the notebook\n",
    "viz.animate(playback_speed=SPEEDUP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d15f98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
